#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (c) 2020, 2024, Niklas Hauser
#
# This file is part of the modm project.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# -----------------------------------------------------------------------------

from collections import defaultdict

def common_usb_irqs(env):
    """
    Filters the USB interrupts by port:

      - `usb_irqs`: All unfiltered USB interrupts
      - `port_irqs`: Filtered USB interrupts categorized by port: `fs` or `hs`.
      - `is_remap`: True when `port_irqs` contains remapped IRQs (specific to STM32F3)

    Interrupts used by USB FS:

     - OTG_FS (no suffix)
     - USB
     - USBWakeUp (no suffix)
     - USB_LP* (any suffix)
     - USB_HP* (any suffix)

    Interrupts used by USB HS:

     - OTG_HS (no suffix)

    The interrupt lists inside `port_irqs` are sorted by name, so the result
    is identical on every run regardless of set iteration order.

    :param env: environment whose `:target` device provides the `core` driver
                with its list of interrupt vectors.
    :returns: a dictionary of USB interrupt properties
    """
    # Every vector whose name starts with "USB" or "OTG" is considered USB related.
    usb_vectors = {v["name"] for v in env[":target"].get_driver("core")["vector"]
                   if v["name"].startswith(("USB", "OTG"))}
    is_remap = any("_RMP" in v for v in usb_vectors)

    port_irqs = defaultdict(list)
    port_irqs_remap = defaultdict(list)
    # Iterate in sorted order: `usb_vectors` is a set of strings, whose
    # iteration order may differ between interpreter runs, which would make
    # the order of the per-port lists unstable.
    for vector in sorted(usb_vectors):
        # Any vector containing "_HS" belongs to the high speed port,
        # everything else is attributed to the full speed port.
        port = "hs" if "_HS" in vector else "fs"
        if is_remap and vector in ["USBWakeUp_RMP", "USB_LP", "USB_HP"]:
            port_irqs_remap[port].append(vector)
        else:
            port_irqs[port].append(vector)

    # Hand out a plain dict for the non-remapped case.
    port_irqs = dict(port_irqs)
    irqs = {
        "usb_irqs": usb_vectors,
        "is_remap": is_remap,
        # With remapping only the remapped vectors are reported per port.
        "port_irqs": port_irqs_remap if is_remap else port_irqs,
    }
    return irqs

def generate_instance(env, port, otg=False):
    """
    Renders the `usb.hpp.in` template for one USB port.

    Sets the output base path and the template substitutions on `env`, then
    renders the template: to `usb_<port>.hpp` for OTG peripherals, otherwise
    to the template's default output name.

    :param env: build environment used for queries, substitutions and templating
    :param port: port identifier, used as key into the per-port IRQ mapping
    :param otg: True when the port belongs to an OTG peripheral
    """
    env.outbasepath = "modm/src/modm/platform/usb"
    usb_interrupts = env.query(":platform:usb:irqs")
    gpio_signals = env.query(":platform:gpio:all_signals")

    # ULPI is only considered available when every one of its signals exists.
    required_ulpi = {"Ulpick", "Ulpistp", "Ulpidir", "Ulpinxt"}
    required_ulpi.update("Ulpid{}".format(bit) for bit in range(8))
    has_all_ulpi = required_ulpi.issubset(gpio_signals)

    env.substitutions = {
        "port": port,
        "peripheral": "Usbotg{}".format(port) if otg else "Usb",
        "is_otg": otg,
        "is_remap": usb_interrupts["is_remap"],
        "irqs": usb_interrupts["port_irqs"][port],
        "target": env[":target"].identifier,
        # ULPI is only reported for the "hs" port.
        "is_ulpi": port == "hs" and has_all_ulpi,
    }

    if not otg:
        env.template("usb.hpp.in")
        return
    env.template("usb.hpp.in", "usb_{}.hpp".format(port))


# -----------------------------------------------------------------------------
def init(module):
    """Sets the name and the description of the USB platform module."""
    module.name, module.description = ":platform:usb", "Universal Serial Bus (USB)"

def prepare(module, options):
    """
    Decides whether the module is available for the target and registers its
    submodules, dependencies and the `irqs` query.

    :param module: module object to register submodules, dependencies and queries on
    :param options: option mapping providing the `:target` device
    :returns: False when the target has none of the supported STM32 USB
              drivers, otherwise True
    """
    target = options[":target"]
    supported_drivers = ("usb:stm32*", "usb_otg_fs:stm32*", "usb_otg_hs:stm32*")
    if not any(target.has_driver(driver) for driver in supported_drivers):
        return False

    # Each OTG peripheral present on the target gets its own submodule.
    for speed in ("fs", "hs"):
        if target.has_driver("usb_otg_" + speed):
            module.add_submodule(UsbInstance(speed))

    module.depends(
        ":architecture:interrupt",
        ":cmsis:device",
        ":platform:gpio",
        ":platform:rcc")

    irq_query = EnvironmentQuery(name="irqs", factory=common_usb_irqs)
    module.add_query(irq_query)

    return True

def build(env):
    """
    Generates the files of the top-level USB module.

    The `fs` instance is generated here only when the target has neither the
    `usb_otg_fs` nor the `usb_otg_hs` driver. `uart.hpp` is copied when the
    `:tinyusb:device:cdc` module is part of the build.

    :param env: build environment
    """
    env.outbasepath = "modm/src/modm/platform/usb"

    target = env[":target"]
    has_otg = target.has_driver("usb_otg_fs") or target.has_driver("usb_otg_hs")
    if not has_otg:
        generate_instance(env, "fs")

    if env.has_module(":tinyusb:device:cdc"):
        env.copy(repopath("ext/hathach/uart.hpp"), "uart.hpp")


# -----------------------------------------------------------------------------
class UsbInstance(Module):
    """Submodule that generates the driver of a single USB OTG port."""

    def __init__(self, speed):
        # Port identifier; doubles as submodule name and template port.
        self.speed = speed

    def init(self, module):
        """Names the submodule after its port and describes its speed."""
        speed_name = "Full" if self.speed == "fs" else "High"
        module.name = self.speed
        module.description = "{} Speed".format(speed_name)

    def prepare(self, module, options):
        """Declares the GPIO dependency; the submodule is always available."""
        module.depends(":platform:gpio")
        return True

    def build(self, env):
        """Renders the OTG variant of the USB template for this port."""
        generate_instance(env, self.speed, otg=True)
